local promise         = require('promise')
local async           = require('async')
local helpers         = require('spec.helpers.init')
local compat          = require('promise-async.compat')
local deferredPromise = helpers.deferredPromise
local setTimeout      = helpers.setTimeout
local basics          = require('spec.helpers.basics')
local dummy           = {dummy = 'dummy'}
local sentinel        = {sentinel = 'sentinel'}
local sentinel2       = {sentinel = 'sentinel2'}
local sentinel3       = {sentinel = 'sentinel3'}
local other           = {other = 'other'}

describe('async await module.', function()
    -- The promise returned by `async` is checked here: its type, the value it
    -- resolves with, and the reason it rejects with when the executor throws.
    describe('async return a Promise.', function()
        it('return value is a Promise', function()
            local noop = function() end
            local returned = async(noop)
            assert.True(promise.isInstance(returned))
        end)

        it('async without return statement', function()
            local function onResolved(resolved)
                assert.equal(nil, resolved)
                done()
            end

            async(function() end):thenCall(onResolved)
            assert.True(wait())
        end)

        it('async return a single', function()
            local function executor()
                return dummy
            end

            local function onResolved(resolved)
                assert.equal(dummy, resolved)
                done()
            end

            async(executor):thenCall(onResolved)
            assert.True(wait())
        end)

        it('async return multiple values, which are packed into resolved result in Promise', function()
            local function executor()
                return sentinel, sentinel2, sentinel3
            end

            async(executor):thenCall(function(packed)
                -- All returned values are expected in one table, in order.
                assert.same({sentinel, sentinel2, sentinel3}, packed)
                done()
            end)
            assert.True(wait())
        end)

        it('async throw error', function()
            local function executor()
                error(dummy)
                return other
            end

            local function onRejected(reason)
                -- The thrown value itself must arrive as the rejection reason.
                assert.equal(dummy, reason)
                done()
            end

            async(executor):thenCall(nil, onRejected)
            assert.True(wait())
        end)
    end)

    describe('executor inside async.', function()
        -- Checks which kinds of executor `async` accepts and which it rejects.
        describe('must be either a function or a callable table,', function()
            it('other values', function()
                -- The element count is stored explicitly and the table is walked
                -- by index: a `nil` entry is never produced by `pairs`, so a
                -- `pairs` loop would silently skip the `async(nil)` case.
                local errorValues = {n = 4, nil, 0, '0', true}
                for i = 1, errorValues.n do
                    local v = errorValues[i]
                    assert.error(function()
                        async(v)
                    end)
                end
            end)

            it('a function', function()
                -- The executor schedules `done()`; the test passes once it fires.
                async(function()
                    setTimeout(function()
                        done()
                    end, 10)
                end)
                assert.True(wait())
            end)

            it('a callable table', function()
                -- A table whose metatable provides `__call` is used as executor.
                async(setmetatable({}, {
                    __call = function()
                        setTimeout(function()
                            done()
                        end, 10)
                    end
                }))
                assert.True(wait())
            end)
        end)

        it('should run immediately', function()
            -- The spy must already have been invoked when `async` returns;
            -- no waiting happens before the assertion.
            local body = spy.new(function() end)
            async(body)
            assert.spy(body).was_called()
        end)

        it('until the await is called, executor should run immediately even if in nested function', function()
            local value
            local executor = spy.new(function() end)
            async(function()
                -- Uses the same `await` entry point as every other test in this
                -- file rather than a one-off `async.wait` spelling.
                value = await(async(function()
                    return async(function()
                        executor()
                        return dummy
                    end)
                end))
                done()
            end)
            -- Checked before waiting: the innermost executor has to have run
            -- during the outer `async` call itself.
            assert.spy(executor).was_called()
            assert.True(wait())
            assert.equal(dummy, value)
        end)
    end)

    describe([[await inside async's executor.]], function()
        -- Registers one test that awaits a pending promise and checks that the
        -- awaited result equals `expectedValue` once the promise is resolved.
        -- `stringRepresentation` is only used to build the test title.
        local function testBasicAwait(expectedValue, stringRepresentation)
            local title = 'should wait for the promise with resolved value: ' .. stringRepresentation
            it(title, function()
                local awaited
                local pending, fulfil = deferredPromise()
                local function executor()
                    awaited = await(pending)
                    done()
                end
                async(executor)
                -- `done()` must not have been reached while the promise is pending.
                assert.False(wait(10))
                fulfil(expectedValue)
                assert.True(wait())
                assert.equal(expectedValue, awaited)
            end)
        end

        -- One test per entry of the `basics` helper table; each entry is a
        -- factory that is called to obtain the value to resolve with.
        for label, makeValue in pairs(basics) do
            testBasicAwait(makeValue(), label)
        end
    end)

    describe('`pcall` and `xpcall` surround statement or function.', function()
        it('call `pcall` to get the value from a single await', function()
            local succeeded, result
            local pending, fulfil = deferredPromise()
            local function executor()
                succeeded, result = pcall(await, pending)
                done()
            end
            async(executor)
            -- Still pending: the executor must not have finished yet.
            assert.False(wait(10))
            fulfil(dummy)
            assert.True(wait())
            assert.True(succeeded)
            assert.equal(dummy, result)
        end)

        it('call `xpcall` to get the value from a single await', function()
            local succeeded, result
            local pending, fulfil = deferredPromise()
            local function executor()
                -- The message handler is a no-op; `pending` is forwarded to `await`.
                succeeded, result = xpcall(await, function() end, pending)
                done()
            end
            async(executor)
            assert.False(wait(10))
            fulfil(dummy)
            assert.True(wait())
            assert.True(succeeded)
            assert.equal(dummy, result)
        end)

        it('call `pcall` to catch the reason from a single await', function()
            local succeeded, caught
            -- Only the promise and its reject function are needed here.
            local pending, _, fail = deferredPromise()
            local function executor()
                succeeded, caught = pcall(await, pending)
                done()
            end
            async(executor)
            assert.False(wait(10))
            fail(dummy)
            assert.True(wait())
            assert.False(succeeded)
            assert.equal(dummy, caught)
        end)

        it('call `xpcall` to catch the reason from a single await', function()
            local succeeded, caught
            local pending, _, fail = deferredPromise()
            local function executor()
                -- The handler records the error value it is given.
                succeeded = xpcall(await, function(err)
                    caught = err
                end, pending)
                done()
            end
            async(executor)
            assert.False(wait(10))
            fail(dummy)
            assert.True(wait())
            assert.False(succeeded)
            assert.equal(dummy, caught)
        end)

        -- A function containing two awaits is run under `pcall`; the tests
        -- differ in whether the rejected promise is awaited second or first.
        describe('call `pcall` to catch the reason from a function,', function()
            it('throw error after the result return by await', function()
                local succeeded, received, caught
                local first, fulfilFirst = deferredPromise()
                local second, _, failSecond = deferredPromise()
                local function executor()
                    succeeded, caught = pcall(function()
                        received = await(first)
                        await(second)
                    end)
                    done()
                end
                async(executor)
                -- The second promise is rejected as a reaction to the first
                -- one being resolved.
                first:thenCall(function()
                    failSecond(dummy)
                end)
                assert.False(wait(10))
                setTimeout(function()
                    fulfilFirst(other)
                end, 20)
                assert.True(wait())
                assert.False(succeeded)
                assert.equal(other, received)
                assert.equal(dummy, caught)
            end)

            it('throw error before the result return by await', function()
                local succeeded, received, caught
                local first, _, failFirst = deferredPromise()
                local second, fulfilSecond = deferredPromise()
                local function executor()
                    succeeded, caught = pcall(function()
                        await(first)
                        received = await(second)
                    end)
                    done()
                end
                async(executor)
                -- The second promise is resolved up front; the first one,
                -- which is awaited first, is rejected later.
                fulfilSecond(other)
                assert.False(wait(10))
                setTimeout(function()
                    failFirst(dummy)
                end, 20)
                assert.True(wait())
                assert.False(succeeded)
                assert.equal(nil, received)
                assert.equal(dummy, caught)
            end)
        end)

        -- Same scenarios as the `pcall` variant, but the error value is
        -- captured by the `xpcall` message handler.
        describe('call `xpcall` to catch the reason from a function,', function()
            it('throw error after the result return by await', function()
                local succeeded, received, caught
                local first, fulfilFirst = deferredPromise()
                local second, _, failSecond = deferredPromise()
                local function executor()
                    succeeded = xpcall(function()
                        received = await(first)
                        await(second)
                    end, function(err)
                        caught = err
                    end)
                    done()
                end
                async(executor)
                -- The second promise is rejected as a reaction to the first
                -- one being resolved.
                first:thenCall(function()
                    failSecond(dummy)
                end)
                assert.False(wait(10))
                setTimeout(function()
                    fulfilFirst(other)
                end, 20)
                assert.True(wait())
                assert.False(succeeded)
                assert.equal(other, received)
                assert.equal(dummy, caught)
            end)

            it('throw error before the result return by await', function()
                local succeeded, received, caught
                local first, _, failFirst = deferredPromise()
                local second, fulfilSecond = deferredPromise()
                local function executor()
                    succeeded = xpcall(function()
                        await(first)
                        received = await(second)
                    end, function(err)
                        caught = err
                    end)
                    done()
                end
                async(executor)
                -- The second promise is resolved up front; the first one,
                -- which is awaited first, is rejected later.
                fulfilSecond(other)
                assert.False(wait(10))
                setTimeout(function()
                    failFirst(dummy)
                end, 20)
                assert.True(wait())
                assert.False(succeeded)
                assert.equal(nil, received)
                assert.equal(dummy, caught)
            end)
        end)
    end)

    describe('nested async functions.', function()
        it('simple call chain', function()
            local received
            local function executor()
                -- An async call whose executor returns another async call.
                received = await(async(function()
                    return async(function()
                        return dummy
                    end)
                end))
                done()
            end
            async(executor)
            assert.True(wait())
            assert.equal(dummy, received)
        end)

        it('deferred call', function()
            local received
            local function executor()
                received = await(async(function()
                    return async(function()
                        return sentinel
                    end)
                end))
                done()
            end
            -- The outer `async` call itself is started from a timer callback.
            setTimeout(function()
                async(executor)
            end, 10)
            assert.True(wait())
            assert.equal(sentinel, received)
        end)

        it('return multiple values', function()
            local first, second, third, extra
            local function executor()
                -- Four targets for three returned values: the last stays nil.
                first, second, third, extra = await(async(function()
                    return sentinel, sentinel2, sentinel3
                end))
                done()
            end
            async(executor)
            assert.True(wait())
            assert.equal(sentinel, first)
            assert.equal(sentinel2, second)
            assert.equal(sentinel3, third)
            assert.equal(nil, extra)
        end)

        it('should catch error from the deepest callee', function()
            local succeeded, received, caught
            local function executor()
                succeeded = xpcall(function()
                    received = await(async(function()
                        return async(function()
                            error(dummy)
                            return other
                        end)
                    end))
                end, function(err)
                    -- Records whatever error value reaches the handler.
                    caught = err
                end)
                done()
            end
            async(executor)
            assert.True(wait())
            assert.False(succeeded)
            assert.equal(nil, received)
            assert.equal(dummy, caught)
        end)

        it('should wait for the deepest callee', function()
            local value
            local p, resolve = deferredPromise()
            async(function()
                -- Three nested levels, each awaiting the level below it;
                -- only the innermost one awaits the deferred promise.
                value = await(async(function()
                    return await(async(function()
                        return await(p)
                    end))
                end))
                done()
            end)
            -- While `p` is pending, `done()` must not have been reached;
            -- without this check the test never observes any waiting.
            assert.False(wait(10))
            resolve(dummy)
            assert.True(wait())
            assert.equal(dummy, value)
        end)

        it('should wait for all callees', function()
            local packed
            local first, fulfilFirst = deferredPromise()
            local second, fulfilSecond = deferredPromise()
            local third, fulfilThird = deferredPromise()

            local function executor()
                packed = compat.pack(await(async(function()
                    return await(first), await(async(function()
                        return await(second), await(third)
                    end))
                end)))
                done()
            end
            async(executor)

            -- The promises are resolved in reverse order; `done()` must not be
            -- reached until the last of them has been resolved.
            assert.False(wait(10))
            fulfilThird(sentinel3)
            assert.False(wait(10))
            fulfilSecond(sentinel2)
            assert.False(wait(10))
            fulfilFirst(sentinel)
            assert.True(wait())
            assert.same({sentinel, sentinel2, sentinel3, n = 3}, packed)
        end)
    end)
end)
